
import numpy as np
from Tx import *
from eq import *
from channel import *
from simSweep import *
from timeit import default_timer as timer
from datetime import timedelta
import matplotlib.pyplot as plt
import sys
from neuralEQ import *
from neuralEQRNN import *
from simNeuralEQRNN import *
from simNeuralEQ import *
from FwdBwdNeuralEqV3 import *
import device
import os
import time

import torch
from torch import nn
import argparse
import pickle
from torchsummary import *


#@@ Wall-clock reference taken when the script starts; it is read again on the
#@@ last lines of the file to report the total simulation time.
startTime = time.time()
'''****************************************************
Helper functions for saving and loading a list with pickle.
They are used to cache the fwdBwd outputs and the matching channel data.
****************************************************'''
def saveList(fileName, l):
	'''Pickle the object `l` into the file `fileName`.

	The file is opened in binary write mode, so existing content is replaced.
	'''
	with open(fileName, mode="wb") as sink:
		pickle.Pickler(sink).dump(l)

def loadList(fileName):
	'''Return the object unpickled from the file `fileName`.

	NOTE: unpickling can execute arbitrary code; only load files this
	project wrote itself (e.g. via saveList).
	'''
	with open(fileName, mode="rb") as source:
		return pickle.Unpickler(source).load()



'''*******************************************************
Function for post-processing the training labels according to the parameters (lossFn, mod, simpleDataTraining, ...)
*******************************************************'''
def trainSetPostProcess (lossFn, nrzNnOutOne, simpleDataTraining, fwdBwdProbTrain, chInTrain, mod):
	'''Convert training labels into the layout required by the selected loss.

	Parameters
		lossFn             : 'crossEntropy', 'manualCrossEntropy' or 'mse'. Any other
		                     value returns the labels unchanged (as a list).
		nrzNnOutOne        : truthy -> single-output network for NRZ. Not supported
		                     together with 'crossEntropy' (the script exits).
		simpleDataTraining : truthy -> build the labels from the TX symbols in
		                     chInTrain; falsy -> use fwdBwdProbTrain.
		fwdBwdProbTrain    : array indexed as [sample, class]; only read when
		                     simpleDataTraining is falsy (or lossFn is unknown).
		chInTrain          : 1-D sequence of TX symbols; only read when
		                     simpleDataTraining is truthy.
		mod                : 'nrz', 'pam4' or 'pam8'.

	Returns
		A list with one entry per sample:
		  crossEntropy        -> class index
		  manualCrossEntropy  -> one-hot row (simple labels) or the input rows
		  mse                 -> same rows as manualCrossEntropy, shifted by -0.5

	The number of output classes is derived from `mod` itself, so the function
	does not depend on a module-level modNum being defined and consistent.
	'''
	#@@ Single-output layout is only meaningful for NRZ.
	singleOut = bool(nrzNnOutOne) and mod == 'nrz'

	if lossFn == 'crossEntropy': # Cross Entropy loss: labels are class indices
		if nrzNnOutOne:
			#@@ Message goes to stderr and the exit status is non-zero.
			sys.exit("nrzNnOutOne & crossEntropy do not work")
		if simpleDataTraining:
			symbols = np.asarray(chInTrain)
			classIdx = _symbolClassIndex(symbols, mod)
			if classIdx is None:
				#@@ No symbol mapping for this modulation: labels stay all-zero.
				fwdBwdProbTrain = np.zeros(len(symbols))
			elif mod == 'pam4':
				#@@ pam4 labels have always been stored as float64; keep that dtype.
				fwdBwdProbTrain = classIdx.astype(float)
			else:
				fwdBwdProbTrain = classIdx
		else:
			#@@ Hard decision on the fwdBwd probabilities.
			fwdBwdProbTrain = np.argmax(fwdBwdProbTrain,axis=1)

	elif lossFn in ('manualCrossEntropy', 'mse'):
		if simpleDataTraining:
			fwdBwdProbTrain = _simpleOneHot(np.asarray(chInTrain), mod, singleOut)
		elif singleOut:
			#@@ Keep only the first probability column for the single-output net.
			fwdBwdProbTrain = fwdBwdProbTrain[:,0]
			if lossFn == 'mse':
				fwdBwdProbTrain = fwdBwdProbTrain.reshape(-1,1)
		if lossFn == 'mse':
			#@@ Mean value to 0 by subtracting 0.5.
			fwdBwdProbTrain = np.array(fwdBwdProbTrain)-0.5

	return list(fwdBwdProbTrain)


def _symbolClassIndex(symbols, mod):
	'''Return the class index of every TX symbol, or None if `mod` has no mapping.

	nrz : symbol == 1 -> 0, anything else -> 1
	pam4: symbol == 1 -> 0, >= 0 -> 1, > -1 -> 2, otherwise -> 3
	'''
	if mod == 'nrz':
		return np.where(symbols == 1, 0, 1)
	if mod == 'pam4':
		classIdx = np.full(symbols.shape, 3, dtype=np.intp)
		#@@ Applied from the loosest to the tightest condition so that the
		#@@ tightest matching condition wins, like the original if/elif chain.
		classIdx[symbols > -1] = 2
		classIdx[symbols >= 0] = 1
		classIdx[symbols == 1] = 0
		return classIdx
	#@@ NOTE(review): 'pam8' has no symbol mapping, so simple-TX labels stay
	#@@ all-zero for it (historical behaviour) - confirm whether that is intended.
	return None


def _simpleOneHot(symbols, mod, singleOut):
	'''Build the float label matrix (one row per symbol) from TX symbols.'''
	if mod not in _MOD_LEVELS:
		raise ValueError('invalid modulation: %r' % (mod,))
	width = 1 if singleOut else _MOD_LEVELS[mod]
	classIdx = _symbolClassIndex(symbols, mod)
	if classIdx is None:
		return np.zeros((len(symbols), width))
	if singleOut:
		#@@ Single output: 1 for the +1 symbol, 0 otherwise.
		return (classIdx == 0).astype(float).reshape(-1,1)
	#@@ Row k of the identity matrix is the one-hot vector of class k.
	return np.eye(width)[classIdx]


#@@ Number of output classes per modulation.
_MOD_LEVELS = {'nrz': 2, 'pam4': 4, 'pam8': 8}


'''****************************************************
Command-line parser definition.
name: used for naming the result directory
config: name of the input config (selects the config module to import)
****************************************************'''
def parsing_def():
	'''Parse the command-line options of this script.

	Returns the argparse namespace with two mandatory string attributes:
		name   : run name; used in the ./results/<name>/ output paths.
		config : config suffix; the module config_<config> is imported from ./config.

	argparse exits the process with a usage message if either option is missing.
	'''
	parser = argparse.ArgumentParser(description="Run the conventional equalizer sweeps and the neural equalizer training selected by a config module")
	#@@ Both options are required, so no default value is given.
	parser.add_argument('-n', '--name', type=str, required=True, help="run name, used for naming the result directory under ./results")
	parser.add_argument('-c', '--config', type=str, required=True, help="config name; ./config/config_<CONFIG>.py is loaded")
	return parser.parse_args()

#@@ Parse the command line at module level; args.name and args.config are used below.
args = parsing_def()

#@@ Put ./config at the front of the import path, import the module named
#@@ config_<args.config>, and take its module-level `config` object. Every
#@@ parameter below is read from that object.
sys.path.insert(0, './config')
config_module = __import__('config_{}'.format(args.config))
config = config_module.config


'''*****************************************************
Parameters loaded from the input config file.
A config file is mandatory to run main.py.
These parameters define
	- which equalizer or training is performed
	- how the equalizer operates
*****************************************************'''
#@@ Load parameters from config file
#@@ Handles to the config sections that are read several times below.
_paramCfg = config['Parameter']
_trainSweepCfg = config['NeuralEQ Training for Various SNR']

#@@ Data set sizes (lengths).
dataSizeTrain = _paramCfg['dataSizeTrain']		#@@ training set
dataSizeValid = _paramCfg['dataSizeValid']		#@@ validation set
dataSizeTest = _paramCfg['dataSizeTest']		#@@ test set; also used as the normal equalizer input length
dataSizeGeneral = dataSizeTest

#@@ Channel description.
chSBR = _paramCfg['chSBR']						#@@ ISI added by the channel, e.g. [1.0,0.4]
eqSBR = _paramCfg['eqSBR']						#@@ response used for the equalizing calculation

#@@ Mini-batch size for training: 1000 when forceTrainIn is set, otherwise
#@@ 1/100 of the training set; a non-None batchSizeOvrd replaces either value.
forceTrainIn = _paramCfg['forceTrainIn']
batchSize = 1000 if forceTrainIn else int(dataSizeTrain/100)
if _paramCfg['batchSizeOvrd'] is not None:
	batchSize = _paramCfg['batchSizeOvrd']

#@@ Neural network geometry.
inSize = _paramCfg['inSize']					#@@ input size of neural network
outSize = _paramCfg['outSize']					#@@ output size of neural network
delay = int((inSize+1)/2)

#@@ SNR settings.
snrTrain = config['NeuralEQ model gen and run']['snrTrain']
snrTrainList = _trainSweepCfg['snrTrainList']	#@@ snr for training data set
snrValid = _paramCfg['snrValid']				#@@ snr for validation data set
snrTest = _paramCfg['snrTest']					#@@ snr for test data set
flagN = _paramCfg['flagN']						#@@ noise adding on/off

#@@ Training hyper-parameters.
numEpoch = _paramCfg['numEpoch']				#@@ number of training epochs
lossFn = _paramCfg['lossFn']					#@@ loss function selection. crossEntropy / manualCrossEntropy / mse
lrInit = _paramCfg['lrInit']					#@@ initial learning rate
gamma = _paramCfg['gamma']						#@@ learning rate decay constant
stepSize = _paramCfg['stepSize']				#@@ learning rate decay step size (unit: epoch)
weightDecay = _paramCfg['weightDecay']

#@@ Modulation and label selection.
mod = _paramCfg['mod']							#@@ modulation. nrz or pam4 or pam8
simpleDataTraining = _paramCfg['simpleDataTraining']	#@@ training label selection (0: fwdBwd output / 1: simple TX data)

#@@ Remaining switches, read as-is from the config.
mismatchSNR = _paramCfg['mismatchSNR']
nnSel = _paramCfg['nnSel']
onTheFly = _paramCfg['onTheFly']
useFwdBwdNeuralEq = _paramCfg['useFwdBwdNeuralEq']
hiddenStage = _paramCfg['hiddenStage']
depth = _paramCfg['depth']
pruneIter = _paramCfg['pruneIter']
pruneRatio = _paramCfg['pruneRatio']
nrzNnOutOne = False
earlyStop = _paramCfg['earlyStop']

## Not used by the training below; only echoed in the parameter log.
hiddenSize = 100
numLayers = 2
bidir = True
numEpochRnn = 100
batchSizeRnn = 1
seqLength = 100
##

#@@ Number of signal levels for the chosen modulation.
_modLevelTable = {'nrz': 2, 'pam4': 4, 'pam8': 8}
if mod not in _modLevelTable:
	sys.exit('invalid modulation')
modNum = _modLevelTable[mod]
#@@ Print parameters for logging
#@@ Print parameters for logging.
#@@ The report is a single string built from adjacent f-string literals, one
#@@ "name: value" line per parameter, and printed with one flushed call.
print("")
print("----------------Simulation parameter----------------")
print(
	f"dataSizeTrain: {dataSizeTrain}\n"
	f"dataSizeValid: {dataSizeValid}\n"
	f"dataSizeTest: {dataSizeTest}\n"
	f"chSBR: {chSBR}\n"
	f"eqSBR: {eqSBR}\n"
	f"batchSize: {batchSize}\n"
	f"inSize: {inSize}\n"
	f"outSize: {outSize}\n"
	f"delay: {delay}\n"
	f"snrTrain: {snrTrain}\n"
	f"snrValid: {snrValid}\n"
	f"flagN: {flagN}\n"
	f"numEpoch: {numEpoch}\n"
	f"hiddenSize: {hiddenSize}\n"
	f"numLayers: {numLayers}\n"
	f"bidir: {bidir}\n"
	f"numEpochRnn: {numEpochRnn}\n"
	f"batchSizeRnn: {batchSizeRnn}\n"
	f"seqLength: {seqLength}\n"
	f"mod: {mod}\n"
	f"lrInit: {lrInit}\n"
	f"gamma: {gamma}\n"
	f"stepSize: {stepSize}\n"
	f"snrTrainList: {snrTrainList}\n"
	f"onTheFly: {onTheFly}\n"
	f"forceTrainIn: {forceTrainIn}\n"
	f"mismatchSNR: {mismatchSNR}\n"
	f"earlyStop: {earlyStop}\n"
	f"useFwdBwdNeuralEq: {useFwdBwdNeuralEq}\n",
	flush=True,
)
print("----------------------------------------------------", flush=True)
print("")

#@@ The training set must be larger than one mini-batch. Abort otherwise.
#@@ Passing the message to sys.exit() writes it to stderr and gives a
#@@ non-zero exit status, so a failed run is not reported as a success
#@@ (same form as the 'invalid modulation' exit above).
if dataSizeTrain <= batchSize:
	sys.exit("Error: dataSizeTrain must be larger than batchSize")

#@@ Seed fix. (Note that it doesn't fix the randomness of neural network parameters)
np.random.seed(1)




'''*****************************************************
Equalizer running. (It doesn't include the training part.)
	e.g. FIR(FFE), DFE, FIR+DFE, VITERBI, FWDBWD, FWD, nEQ
Which EQs are run is selected in the config file.
snrList can be defined there, and the simulation runs for each snr in snrList.
Parameters for a specific EQ are defined here (ffeTapNum, dfeTapNum, ...).
If you have trained parameters of an nEQ, you can run that nEQ in this part.
*****************************************************'''

#@@ Transmitter and the symbol sequence shared by every conventional-EQ run.
tx = Tx(mod=mod)
chInGeneral = tx.run(dataSizeGeneral)
#********************************#
# Switching which EQ is ON 
#********************************#
#@@ Each flag in config['Conventional EQ'] enables one equalizer; every enabled
#@@ equalizer is run by `sweep` over all SNR values in snrList.
snrList = config['Conventional EQ']['snrList']
sweep = simSweep(chSbr=chSBR, eqSbr=eqSBR, snrList=snrList, originData=chInGeneral,mod=mod,flagN=flagN, stateGen=True)

#@@ Stay None when the matching equalizer is switched off; the (currently
#@@ disabled) BER plot further down checks for None before plotting them.
firBerList = None
dfeBerList = None
firDfeBerList = None
if int(config['Conventional EQ']['FIR_ber']):
	firBerList = sweep.fir(ffeTapNum=24)
	print("",flush=True)
if int(config['Conventional EQ']['DFE_ber']):
	dfeBerList = sweep.dfe(dfeTapNum=None)
	print("",flush=True)
if int(config['Conventional EQ']['FIRDFE_ber']):
	firDfeBerList = sweep.firDfe(ffeTapNum=24, ffeMaxTapNum=14, dfeTapNum=1) #ffeTapNum
	print("",flush=True)
if int(config['Conventional EQ']['VITERBI_ber']):
	sweep.viterbiOverlap(blockSizeList=[100]) #default
	print("",flush=True)
if int(config['Conventional EQ']['FWDBWD_ber']):
	fwdBwdBerList, fwdBwdProb = sweep.fwdBwd(fwdBwdLen=10)
	print("",flush=True)
if int(config['Conventional EQ']['FWD_ber']):
	fwdBerList, fwdProb = sweep.fwd(fwdLen=5)
	print("",flush=True)
if int(config['Conventional EQ']['nnFWDBWD_ber']):
	#@@ Run a previously trained neural EQ. The parameter file is hard-coded;
	#@@ edit paramFile to evaluate a different trained model.
	paramFile = './results/fw_isi0_in12n64/nEQ_pam4_20dB_simp1_mse.pt'
	if os.path.exists(paramFile):
		nEQLoad = torch.load(paramFile)
		nEQLoad = nEQLoad.to(device.device)
		nnFwdBwdBerList = sweep.nnFwdBwd(neuralEQ=nEQLoad, lossFn=nn.MSELoss(), batchSize=10000, inSize=inSize, outSize=outSize, delay=delay)
		print("",flush=True)
	else:
		#@@ Report the path that was actually looked up. The old message
		#@@ formatted `nnSnr`, which is not defined anywhere, so a missing
		#@@ file raised NameError instead of printing this notice.
		print("")
		print("Not exists %s"%paramFile)
		print("")


'''*****************************************************
Tx and channel definition.
Tx generates random data according to the modulation.
Channel adds ISI and noise.
A channel is defined three times: for training, validation and test.
Note that the test set length is also used for the normal equalizers.
*****************************************************'''


#@@ Validation sequence, evaluated every 10 epochs during training.
#@@ It is generated once here and reused for every training SNR.
#@@ NOTE(review): tx.run / Channel.run presumably draw from the NumPy RNG seeded
#@@ above, so the order of these statements would affect the generated data - TODO confirm.
chInValid = tx.run(dataSizeValid)
ch2 = Channel(sbr=chSBR, snr=snrValid)
chOutValid = ch2.run(chIn = chInValid, flagN=flagN)


#@@ Test sequence for the final evaluation after training.
chInTest = tx.run(dataSizeTest)
ch3 = Channel(sbr=chSBR, snr=snrTest)
chOutTest = ch3.run(chIn = chInTest, flagN=flagN)

#@@ Disabled debug plot of channel input vs. channel output for the test set.
if 0:	# for debug
	plt.plot(chInTest,'-o',label='chIn')
	plt.plot(chOutTest,'-*',label='chOut')
	#plt.show()


'''*********************************************
Neural EQ training for various SNRs and misc.
Training is performed for each combination of snrTrainList, lossFn, and simpleDataTraining.
If you want to reduce or increase the sweep cases, modify here. (It's not controlled by the config file now. TODO)
*********************************************'''


#@@ Training sweep: for every (simpleDataTraining, lossFn, snrTrain) combination
#@@   1. build a fresh network, optimizer and learning-rate scheduler,
#@@   2. load the cached training data or generate it (and run fwdBwd if needed),
#@@   3. convert the labels with trainSetPostProcess,
#@@   4. train with validation every 10 epochs and optional early stop,
#@@   5. save the network and evaluate it on the test set.
for simpleDataTraining in config['NeuralEQ Training for Various SNR']['simpleDataTraining']:
	for lossFn in config['NeuralEQ Training for Various SNR']['lossFn']:
		print ("")
		print ("")
		print (f"simpleDataTraining: {simpleDataTraining}")
		print (f"lossFn: {lossFn}")
	
		if int(config['NeuralEQ Training for Various SNR']['nEQ_training_on']):
			#@@ Trained networks (and the optional plots) are written to
			#@@ ./results/<name>/. Create it before training so that a missing
			#@@ directory cannot fail the save after a finished run.
			os.makedirs('./results/%s'%args.name, exist_ok=True)

			for idx, snrTrain in enumerate(snrTrainList):
				print("")
				print(f"trainIdx: {idx} \t snrTrain: {snrTrain}")
				print("")
				#@@ Neural network definition
				#@@ nrzNnOutOne means network output size is set to 1 for NRZ. But it seems not to work.
				if useFwdBwdNeuralEq:
					nEQ = FwdBwdNeuralEq(hiddenStage, inSize, delay, depth, batchSize, mod)
				elif (nrzNnOutOne == True) and (mod=='nrz'):
					nEQ = neuralEQ(inSize=inSize, outSize=outSize*1, mod=mod) # separate ver.
				else:
					nEQ = neuralEQ(inSize=inSize, outSize=outSize*modNum, mod=mod, nnSel=nnSel)
				nEQ = nEQ.to(device.device)

				#@@ Optimizer definition. Adam is selected.
				opt = torch.optim.Adam(nEQ.parameters(), lr=lrInit, weight_decay=weightDecay)

				#@@ Scheduler definition.
				#@@ gamma=1 means no learning rate change.
				sch = torch.optim.lr_scheduler.StepLR(opt, step_size=stepSize, gamma=gamma)
				print("")
				print("----------------NeuralNet parameter----------------")
				print(nEQ)
				print(lossFn)
				print(opt)
				print("---------------------------------------------------")
				print("")
			
				summary(nEQ, (batchSize,inSize), batch_size=batchSize, device=device.device)

				#@@ Legacy is the inverse of the trainRealChannel switch.
				Legacy = not config['Parameter']['trainRealChannel']

				#@@ SNR of the training channel. It is also the SNR written into
				#@@ the cache file names. mismatchSNR, when set, offsets it from snrTrain.
				if (mismatchSNR is not None):
					snrTrainChannel = snrTrain+mismatchSNR
				else:
					snrTrainChannel = snrTrain

				#@@ Check if pre-simulated fwdBwd exists.
				#@@ If the corresponding files (snr, sbr, mod ...) exist, just load from them.
				#@@ If not, generate the training data (and run the fwdBwd algorithm when needed).
				if forceTrainIn:
					fwdBwdProbFileName = 'caching_data/probNew_less09.list'
					fwdBwdProbChOutFileName = 'caching_data/chOutNew_less09.list'
					fwdBwdProbChInFileName = 'caching_data/chInNew_less09.list'
				else:
					fwdBwdProbFileName = './caching_data/%s_fwdBwdProb_size%d_%s_snr%ddB.list'%(mod,dataSizeTrain,eqSBR,snrTrainChannel)
					fwdBwdProbChOutFileName = './caching_data/%s_fwdBwdProbChOut_size%d_%s_snr%ddB.list'%(mod,dataSizeTrain,eqSBR,snrTrainChannel)
					fwdBwdProbChInFileName = './caching_data/%s_fwdBwdProbChIn_size%d_%s_snr%ddB.list'%(mod,dataSizeTrain,eqSBR,snrTrainChannel)
				if (os.path.exists(fwdBwdProbFileName)):
					#@@ Existing case. Load from the files.
					print("")
					print("File(%s) exists, load from file"%fwdBwdProbFileName)
					print("")
					fwdBwdProbTrain = loadList(fwdBwdProbFileName)
					chOutTrain = loadList(fwdBwdProbChOutFileName)
					chInTrain = loadList(fwdBwdProbChInFileName)
				else:
					#@@ Train sequence gen
					chInTrain = tx.run(dataSizeTrain)
					ch = Channel(sbr=chSBR, snr=snrTrainChannel)
					chOutTrain = ch.run(chIn = chInTrain, flagN=flagN)

					#@@ With simple TX labels no fwdBwd output is produced here. Bind the
					#@@ name explicitly so it is never unbound or left over from a
					#@@ previous sweep case; trainSetPostProcess builds the labels from chInTrain.
					fwdBwdProbTrain = None
				
					if simpleDataTraining == 0:
						#@@ No existing case. run fwdBwd
						print("")
						print("File(%s) no exists, excute fwdBwd"%fwdBwdProbFileName)
						print("")
						#@@ Running fwdBwd with specified channel output, chOutTrain.
						sweepForTrain = simSweep(chSbr=chSBR, eqSbr=eqSBR, snrList=[snrTrain], originData=chInTrain, chOutList=[chOutTrain], mod=mod, stateGen=True)
						fwdBwdBerListTrain, fwdBwdProbTrain = sweepForTrain.fwdBwd(fwdBwdLen=inSize)
						saveList(fwdBwdProbFileName, fwdBwdProbTrain)
						saveList(fwdBwdProbChOutFileName, chOutTrain)
						saveList(fwdBwdProbChInFileName, chInTrain)
			
				if simpleDataTraining == 0:
					fwdBwdProbTrain = np.array(fwdBwdProbTrain)

				#@@ Post-processing the labels according to the loss function.
				#@@ If simpleDataTraining=1, the fwdBwd output is replaced by simple TX data.
				#@@ fwdBwdProbOut = (modNum)*dataLen
				#@@ TxData = (1)*dataLen.
				#@@ crossEntropy = (1)*dataLen
				#@@ manualCrossEntropy = (modNum)*dataLen
				#@@ mse = (modNum)*dataLen
				#@@ This was an inline copy of trainSetPostProcess; it now calls
				#@@ the shared helper so both code paths stay identical.
				fwdBwdProbTrain = trainSetPostProcess(lossFn, nrzNnOutOne, simpleDataTraining, fwdBwdProbTrain, chInTrain, mod)
				if 0: ############################################################ if 0
					print(f'chInLen: {len(chInTrain)}, chOutLen: {len(chOutTrain)}')
					print(f'chInTrain: {chInTrain[:20]}')
					print(f'chOutTrain: {chOutTrain[:20]}')
					print(f'fwdBwdProbTrain: {fwdBwdProbTrain[:20]}')
		
				######################
				#### Run neuralEQ ####
				######################
				#@@ Define simNeuralEQ(class) with training set (fwdBwd or simpleTx), validation set (always Tx data)
				#@@ txData means yi, rxData means xi
				if 0:
					print(f"fwdBWdProbTrain: {fwdBwdProbTrain}")
				simNEQ = simNeuralEQ(txDataTrain=fwdBwdProbTrain, rxDataTrain=chOutTrain, txDataTest=chInValid, rxDataTest=chOutValid, neuralEQ=nEQ, mod=mod)
				trainLossList = []
				validLossList = []
				trainBerList = []
				validBerList = []
				pastValidBer = None
				for k in range(numEpoch):

					if onTheFly:
						#@@ Train sequence gen: new data is generated for every epoch.
						#@@ NOTE(review): the labels here are always built as ('mse', simple TX data),
						#@@ and the channel uses snrTrain without mismatchSNR, regardless of the
						#@@ current sweep case - confirm this is intended.
						chInTrainOnTheFly = tx.run(dataSizeTrain)
						chOnTheFly = Channel(sbr=chSBR, snr=snrTrain)	
						chOutTrainOnTheFly = chOnTheFly.run(chIn = chInTrainOnTheFly, flagN=flagN)
						chInTrainOnTheFly = trainSetPostProcess ('mse', False, 1, chInTrainOnTheFly, chInTrainOnTheFly, mod)
						loss = simNEQ.trainNeuralEQ(lossFn, opt, batchSize=batchSize, inSize=inSize, outSize=outSize, delay=delay, rxDataTrainNew=chOutTrainOnTheFly, txDataTrainNew=chInTrainOnTheFly)
					else:
						#@@ Run training with batchSize, inSize, outSize, delay. (The last three parameters are needed for data curating in simNEQ)
						loss = simNEQ.trainNeuralEQ(lossFn, opt, batchSize=batchSize, inSize=inSize, outSize=outSize, delay=delay)

					#@@ One scheduler step per epoch.
					sch.step()
					trainLossList.append(loss)
					print(f"trainloss: {loss:e},  lr: {sch.get_last_lr()}, epoch:{k}/{numEpoch}", flush=True)
				
					#@@ Run for validation set every 10 EPOCHs
					if (k%10==0):
						#@@ ~New will override init value. Valid data is fixed even if train data is changed with snrTrain. It is intended.
						validLoss, berValid = simNEQ.evalNeuralEQ(lossFn, batchSize=batchSize, inSize=inSize, outSize=outSize, delay=delay, rxDataTestNew=chOutValid, txDataTestNew=chInValid)
						validLossList.append(validLoss)
						validBerList.append(berValid)
						print(f"validloss: {validLoss:e}, validber: {berValid:e}, epoch:{k}/{numEpoch}", flush=True)
						#@@ If past valid BER is better than current BER, which is taken as overfitting, training stops (when earlyStop is set).
						if (pastValidBer is not None):
							if pastValidBer < berValid:
								print(f"Overfitting is detected!")
								if (earlyStop):
									break
						pastValidBer = berValid
						print("")
		
				#@@ After training, the neural network is saved for each snrTrain.
				torch.save(nEQ, './results/%s/nEQ_%s_%ddB_simp%d_%s.pt'%(args.name,mod,snrTrain,simpleDataTraining,lossFn))

				#@@ Finally, running nEQ with test set.
				testLoss, berTest = simNEQ.evalNeuralEQ(lossFn, batchSize=batchSize, inSize=inSize, outSize=outSize, delay=delay, rxDataTestNew=chOutTest, txDataTestNew=chInTest )
				berTestList = [berTest]
				print (f"simpleDataTraining: {simpleDataTraining}")
				print (f"lossFn: {lossFn}")
				print(f"Finaltestloss: {testLoss:e}, testber: {berTest:e}, epoch:{k}/{numEpoch} @ simpleDataTraining:{simpleDataTraining}, lossFn:{lossFn}", flush=True)
				
				#@@ Disabled plots of the training loss and the validation BER.
				if 0:
					plt.figure(0)
					plt.plot(trainLossList,'-', label='trainloss')
					plt.grid(True)
					plt.xlabel('epoch')
					plt.ylabel('loss')
					plt.legend(loc='best')
					plt.savefig('./results/%s/loss_%s_%ddB.png'%(args.name,mod,snrTrain))
					
					if 1:
						plt.figure(1)
						plt.plot(validBerList,'-', label='validber')
						if (firBerList is not None):
							plt.plot(firBerList*len(validBerList),'--',label='firber')
						if (dfeBerList is not None):
							plt.plot(dfeBerList*len(validBerList),'--',label='dfeber')
						plt.plot(berTestList*len(validBerList),'--',label='nnFinalBer')
						plt.grid(True)
						plt.yscale('log')
						plt.ylim([1e-4, 1])
						plt.xlabel('epoch')
						plt.ylabel('ber(accuracy)')
						plt.legend(loc='best')
						plt.savefig('./results/%s/ber_%s_%ddB.png'%(args.name,mod,snrTrain))
		
		
		if 0:
			plt.grid(True)
			plt.legend(loc='best')
			plt.show()
		
		
		#########################
		#### Debug neural EQ ####
		#########################
		#simNEQ = simNeuralEQ(txDataTrain=range(0,10), rxDataTrain=range(0,10), txDataTest=range(0,10), rxDataTest=range(0,10), neuralEQ=nEQ)
		#rxDataSet, txDataSet=simNEQ.curatingData(range(0,10), range(10,20), inSize, outSize, batchSize, delay=delay)
		#
		#print(rxDataSet)
		#print(txDataSet)
		#
		#simNEQ.trainNeuralEQ(lossFn, opt, batchSize=batchSize, inSize=inSize, outSize=outSize, delay=delay)


#@@ Total wall-clock duration of the run, converted from seconds to minutes.
elapsedSeconds = time.time() - startTime
timeSim = elapsedSeconds / 60.
print(f"Total simulation time: {timeSim} mins")
